{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import torch\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.models.detection import FasterRCNN\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from my_py_toolkit.file.file_toolkit import *\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 超参数\n",
    "lr = 1e-3\n",
    "batch_size = 2\n",
    "data_dir = 'F:/VOS/ubutu-vm/resources/datasets/work/train'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = [p for p in model.parameters() if p.requires_grad]\n",
    "opt = Adam(params, lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataloader\n",
    "# prerpocess data\n",
    "images = []\n",
    "boxes = []\n",
    "labels = []\n",
    "paths = []\n",
    "for name in ['gen_train', 'some_true']:\n",
    "    paths.extend(get_file_paths(f'{data_dir}/{name}'))\n",
    "boxes_info = readjson(f'{data_dir}/fake_box.json')\n",
    "for p in paths:\n",
    "    img = cv2.imread(p) / 255\n",
    "    img = cv2.resize(img, (256, 256))\n",
    "    img = torch.tensor(img).permute(2, 0, 1).to(torch.float32)\n",
    "    b_info = boxes_info.get(get_file_name(p), [0, 1, 0, 1, 0])\n",
    "    b_info[:4] = [b_info[0], b_info[2], b_info[1], b_info[3]]\n",
    "    box = torch.tensor(b_info[:4])\n",
    "    lable = torch.tensor(b_info[4])\n",
    "    images.append(img)\n",
    "    boxes.append(box)\n",
    "    labels.append(lable)\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = DataLoader(TensorDataset(torch.stack(images), torch.stack(boxes), torch.stack(labels)), batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/98 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "# train\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 0 loss 0.12573187053203583:  49%|████▉     | 48/98 [35:07<36:35, 43.90s/it]\n",
      "epoch 0 loss 0.1702367216348648: 100%|██████████| 98/98 [35:42<00:00, 21.87s/it] \n",
      "epoch 1 loss 0.6329671144485474: 100%|██████████| 98/98 [31:17<00:00, 19.16s/it] \n",
      "epoch 2 loss 0.736222505569458: 100%|██████████| 98/98 [27:45<00:00, 16.99s/it]  \n",
      "epoch 3 loss 0.32683125138282776: 100%|██████████| 98/98 [25:45<00:00, 15.77s/it]\n",
      "epoch 4 loss 0.06248734891414642: 100%|██████████| 98/98 [26:10<00:00, 16.02s/it]\n",
      "epoch 5 loss 0.10154791176319122: 100%|██████████| 98/98 [28:46<00:00, 17.62s/it] \n",
      "epoch 6 loss 0.7781590819358826: 100%|██████████| 98/98 [25:20<00:00, 15.51s/it] \n",
      "epoch 7 loss 16.679574966430664: 100%|██████████| 98/98 [25:32<00:00, 15.63s/it]\n",
      "epoch 8 loss 0.6923300623893738: 100%|██████████| 98/98 [27:05<00:00, 16.59s/it]\n",
      "epoch 9 loss 84.7992935180664:   7%|▋         | 7/98 [02:15<31:05, 20.50s/it]  "
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-27-64fb820d1472>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      9\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m \u001b[1;32min\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m             \u001b[0mloss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m         \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     12\u001b[0m         \u001b[0mopt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m         \u001b[0mtrain_bar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf'epoch {epoch} loss {loss.item()}'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    305\u001b[0m                 \u001b[0mcreate_graph\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    306\u001b[0m                 inputs=inputs)\n\u001b[1;32m--> 307\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    308\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    154\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m    155\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 156\u001b[1;33m         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    157\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    158\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "for epoch in range(10):\n",
    "    train_bar = tqdm(range(len(dataloader)))\n",
    "    for i, (imgs, boxes, labels) in enumerate(dataloader):\n",
    "        opt.zero_grad()\n",
    "        loss = 0\n",
    "        output = model(imgs, [{'boxes': boxes[i][None], 'labels': labels[i][None]} \n",
    "                      for i in range(batch_size)])\n",
    "        for k, v in output.items():\n",
    "            loss += v\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        train_bar.set_description(f'epoch {epoch} loss {loss.item()}')\n",
    "        train_bar.update(1)\n",
    "    torch.save(model.state_dict(), f'./model/model_{epoch}.pth')\n",
    "    train_bar.close()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FasterRCNN(\n",
       "  (transform): GeneralizedRCNNTransform(\n",
       "      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
       "      Resize(min_size=(800,), max_size=1333, mode='bilinear')\n",
       "  )\n",
       "  (backbone): BackboneWithFPN(\n",
       "    (body): IntermediateLayerGetter(\n",
       "      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "      (bn1): FrozenBatchNorm2d(64, eps=0.0)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "      (layer1): Sequential(\n",
       "        (0): Bottleneck(\n",
       "          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "          (downsample): Sequential(\n",
       "            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          )\n",
       "        )\n",
       "        (1): Bottleneck(\n",
       "          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (2): Bottleneck(\n",
       "          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(64, eps=0.0)\n",
       "          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (layer2): Sequential(\n",
       "        (0): Bottleneck(\n",
       "          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "          (downsample): Sequential(\n",
       "            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "            (1): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          )\n",
       "        )\n",
       "        (1): Bottleneck(\n",
       "          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (2): Bottleneck(\n",
       "          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (3): Bottleneck(\n",
       "          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(128, eps=0.0)\n",
       "          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (layer3): Sequential(\n",
       "        (0): Bottleneck(\n",
       "          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "          (downsample): Sequential(\n",
       "            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "            (1): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          )\n",
       "        )\n",
       "        (1): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (2): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (3): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (4): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (5): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(256, eps=0.0)\n",
       "          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(1024, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (layer4): Sequential(\n",
       "        (0): Bottleneck(\n",
       "          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "          (downsample): Sequential(\n",
       "            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "            (1): FrozenBatchNorm2d(2048, eps=0.0)\n",
       "          )\n",
       "        )\n",
       "        (1): Bottleneck(\n",
       "          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "        (2): Bottleneck(\n",
       "          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn2): FrozenBatchNorm2d(512, eps=0.0)\n",
       "          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): FrozenBatchNorm2d(2048, eps=0.0)\n",
       "          (relu): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (fpn): FeaturePyramidNetwork(\n",
       "      (inner_blocks): ModuleList(\n",
       "        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (2): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (3): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (layer_blocks): ModuleList(\n",
       "        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      )\n",
       "      (extra_blocks): LastLevelMaxPool()\n",
       "    )\n",
       "  )\n",
       "  (rpn): RegionProposalNetwork(\n",
       "    (anchor_generator): AnchorGenerator()\n",
       "    (head): RPNHead(\n",
       "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (roi_heads): RoIHeads(\n",
       "    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)\n",
       "    (box_head): TwoMLPHead(\n",
       "      (fc6): Linear(in_features=12544, out_features=1024, bias=True)\n",
       "      (fc7): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "    )\n",
       "    (box_predictor): FastRCNNPredictor(\n",
       "      (cls_score): Linear(in_features=1024, out_features=91, bias=True)\n",
       "      (bbox_pred): Linear(in_features=1024, out_features=364, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_datas = []\n",
    "test_dir = f'{data_dir}/test_images'\n",
    "for p in get_file_paths(test_dir):\n",
    "    img = cv2.imread(p) / 255\n",
    "    img = cv2.resize(img, (256, 256))\n",
    "    img = torch.tensor(img).permute(2, 0, 1).to(torch.float32)\n",
    "    test_datas.append(img)\n",
    "prediction = model(test_datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),\n",
       "  'labels': tensor([], dtype=torch.int64),\n",
       "  'scores': tensor([], grad_fn=<IndexBackward0>)},\n",
       " {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),\n",
       "  'labels': tensor([], dtype=torch.int64),\n",
       "  'scores': tensor([], grad_fn=<IndexBackward0>)},\n",
       " {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),\n",
       "  'labels': tensor([], dtype=torch.int64),\n",
       "  'scores': tensor([], grad_fn=<IndexBackward0>)},\n",
       " {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),\n",
       "  'labels': tensor([], dtype=torch.int64),\n",
       "  'scores': tensor([], grad_fn=<IndexBackward0>)}]"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "expected scalar type Double but found Float",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-17-d4cb46176a6c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m output = model(imgs.to(torch.float64), [{'boxes': boxes[i][None], 'labels': labels[i][None]} \n\u001b[1;32m----> 2\u001b[1;33m                       for i in range(batch_size)])\n\u001b[0m",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torchvision\\models\\detection\\generalized_rcnn.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, images, targets)\u001b[0m\n\u001b[0;32m     91\u001b[0m                                      .format(degen_bb, target_idx))\n\u001b[0;32m     92\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 93\u001b[1;33m         \u001b[0mfeatures\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackbone\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensors\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     94\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     95\u001b[0m             \u001b[0mfeatures\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'0'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torchvision\\models\\detection\\backbone_utils.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     42\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbody\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfpn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torchvision\\models\\_utils.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     60\u001b[0m         \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     61\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 62\u001b[1;33m             \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     63\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreturn_layers\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m                 \u001b[0mout_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreturn_layers\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    444\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    445\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 446\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    447\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    448\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Users\\EDY\\anaconda3\\envs\\faceparsing\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[1;34m(self, input, weight, bias)\u001b[0m\n\u001b[0;32m    441\u001b[0m                             _pair(0), self.dilation, self.groups)\n\u001b[0;32m    442\u001b[0m         return F.conv2d(input, weight, bias, self.stride,\n\u001b[1;32m--> 443\u001b[1;33m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[0;32m    444\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    445\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: expected scalar type Double but found Float"
     ]
    }
   ],
   "source": [
    "\n",
    "# One target dict per image: torchvision detection models take 'boxes' as a float\n",
    "# tensor [N, 4] and 'labels' as an int64 tensor [N]; [None] adds the leading N axis\n",
    "# (assumes one box and one label per image here -- TODO confirm).\n",
    "# Iterating over len(imgs) instead of batch_size avoids an IndexError when the\n",
    "# last batch is smaller than batch_size.\n",
    "output = model(\n",
    "    imgs.to(torch.float32),\n",
    "    [{'boxes': boxes[i][None].to(torch.float32), 'labels': labels[i][None]}\n",
    "     for i in range(len(imgs))],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss_classifier': tensor(0.0464, grad_fn=<NllLossBackward0>),\n",
       " 'loss_box_reg': tensor(0.0400, grad_fn=<DivBackward0>),\n",
       " 'loss_objectness': tensor(0.0060, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),\n",
       " 'loss_rpn_box_reg': tensor(0.0120, grad_fn=<DivBackward0>)}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  0, 255,   0, 255],\n",
       "        [  8, 255,   0, 255]])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 4])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "boxes[:, None, :].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[184, 184, 166],\n",
       "        [191, 189, 171],\n",
       "        [196, 190, 171],\n",
       "        ...,\n",
       "        [250, 247, 233],\n",
       "        [253, 250, 236],\n",
       "        [255, 252, 238]],\n",
       "\n",
       "       [[183, 183, 167],\n",
       "        [188, 185, 170],\n",
       "        [192, 187, 172],\n",
       "        ...,\n",
       "        [250, 247, 233],\n",
       "        [253, 250, 236],\n",
       "        [255, 253, 239]],\n",
       "\n",
       "       [[188, 187, 177],\n",
       "        [191, 189, 179],\n",
       "        [192, 187, 178],\n",
       "        ...,\n",
       "        [251, 248, 234],\n",
       "        [254, 251, 237],\n",
       "        [255, 254, 240]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[134, 146, 156],\n",
       "        [134, 146, 156],\n",
       "        [134, 146, 156],\n",
       "        ...,\n",
       "        [255, 255, 255],\n",
       "        [255, 255, 255],\n",
       "        [255, 253, 252]],\n",
       "\n",
       "       [[134, 146, 156],\n",
       "        [134, 146, 156],\n",
       "        [134, 146, 156],\n",
       "        ...,\n",
       "        [255, 255, 254],\n",
       "        [255, 255, 254],\n",
       "        [254, 255, 251]],\n",
       "\n",
       "       [[135, 147, 157],\n",
       "        [134, 146, 156],\n",
       "        [134, 146, 156],\n",
       "        ...,\n",
       "        [254, 255, 253],\n",
       "        [254, 255, 251],\n",
       "        [254, 255, 251]]], dtype=uint8)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = cv2.imread('F:/Work/aeye/train/gen_train/fake_PhoneT0000004_3.jpg')\n",
    "# cv2.imread returns None (no exception) when the file is missing or cannot be decoded,\n",
    "# so check here instead of failing later on an operation such as img / 255.\n",
    "assert img is not None, 'cv2.imread returned None: image missing or unreadable'\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.7216, 0.7216, 0.6510],\n",
       "         [0.7490, 0.7412, 0.6706],\n",
       "         [0.7686, 0.7451, 0.6706],\n",
       "         ...,\n",
       "         [0.9804, 0.9686, 0.9137],\n",
       "         [0.9922, 0.9804, 0.9255],\n",
       "         [1.0000, 0.9882, 0.9333]],\n",
       "\n",
       "        [[0.7176, 0.7176, 0.6549],\n",
       "         [0.7373, 0.7255, 0.6667],\n",
       "         [0.7529, 0.7333, 0.6745],\n",
       "         ...,\n",
       "         [0.9804, 0.9686, 0.9137],\n",
       "         [0.9922, 0.9804, 0.9255],\n",
       "         [1.0000, 0.9922, 0.9373]],\n",
       "\n",
       "        [[0.7373, 0.7333, 0.6941],\n",
       "         [0.7490, 0.7412, 0.7020],\n",
       "         [0.7529, 0.7333, 0.6980],\n",
       "         ...,\n",
       "         [0.9843, 0.9725, 0.9176],\n",
       "         [0.9961, 0.9843, 0.9294],\n",
       "         [1.0000, 0.9961, 0.9412]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[0.5255, 0.5725, 0.6118],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         ...,\n",
       "         [1.0000, 1.0000, 1.0000],\n",
       "         [1.0000, 1.0000, 1.0000],\n",
       "         [1.0000, 0.9922, 0.9882]],\n",
       "\n",
       "        [[0.5255, 0.5725, 0.6118],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         ...,\n",
       "         [1.0000, 1.0000, 0.9961],\n",
       "         [1.0000, 1.0000, 0.9961],\n",
       "         [0.9961, 1.0000, 0.9843]],\n",
       "\n",
       "        [[0.5294, 0.5765, 0.6157],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         [0.5255, 0.5725, 0.6118],\n",
       "         ...,\n",
       "         [0.9961, 1.0000, 0.9922],\n",
       "         [0.9961, 1.0000, 0.9843],\n",
       "         [0.9961, 1.0000, 0.9843]]], dtype=torch.float64)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(img/255)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(83,\n",
       " ['backbone.body.conv1.weight',\n",
       "  'backbone.body.layer1.0.conv1.weight',\n",
       "  'backbone.body.layer1.0.conv2.weight',\n",
       "  'backbone.body.layer1.0.conv3.weight',\n",
       "  'backbone.body.layer1.0.downsample.0.weight',\n",
       "  'backbone.body.layer1.1.conv1.weight',\n",
       "  'backbone.body.layer1.1.conv2.weight',\n",
       "  'backbone.body.layer1.1.conv3.weight',\n",
       "  'backbone.body.layer1.2.conv1.weight',\n",
       "  'backbone.body.layer1.2.conv2.weight',\n",
       "  'backbone.body.layer1.2.conv3.weight',\n",
       "  'backbone.body.layer2.0.conv1.weight',\n",
       "  'backbone.body.layer2.0.conv2.weight',\n",
       "  'backbone.body.layer2.0.conv3.weight',\n",
       "  'backbone.body.layer2.0.downsample.0.weight',\n",
       "  'backbone.body.layer2.1.conv1.weight',\n",
       "  'backbone.body.layer2.1.conv2.weight',\n",
       "  'backbone.body.layer2.1.conv3.weight',\n",
       "  'backbone.body.layer2.2.conv1.weight',\n",
       "  'backbone.body.layer2.2.conv2.weight',\n",
       "  'backbone.body.layer2.2.conv3.weight',\n",
       "  'backbone.body.layer2.3.conv1.weight',\n",
       "  'backbone.body.layer2.3.conv2.weight',\n",
       "  'backbone.body.layer2.3.conv3.weight',\n",
       "  'backbone.body.layer3.0.conv1.weight',\n",
       "  'backbone.body.layer3.0.conv2.weight',\n",
       "  'backbone.body.layer3.0.conv3.weight',\n",
       "  'backbone.body.layer3.0.downsample.0.weight',\n",
       "  'backbone.body.layer3.1.conv1.weight',\n",
       "  'backbone.body.layer3.1.conv2.weight',\n",
       "  'backbone.body.layer3.1.conv3.weight',\n",
       "  'backbone.body.layer3.2.conv1.weight',\n",
       "  'backbone.body.layer3.2.conv2.weight',\n",
       "  'backbone.body.layer3.2.conv3.weight',\n",
       "  'backbone.body.layer3.3.conv1.weight',\n",
       "  'backbone.body.layer3.3.conv2.weight',\n",
       "  'backbone.body.layer3.3.conv3.weight',\n",
       "  'backbone.body.layer3.4.conv1.weight',\n",
       "  'backbone.body.layer3.4.conv2.weight',\n",
       "  'backbone.body.layer3.4.conv3.weight',\n",
       "  'backbone.body.layer3.5.conv1.weight',\n",
       "  'backbone.body.layer3.5.conv2.weight',\n",
       "  'backbone.body.layer3.5.conv3.weight',\n",
       "  'backbone.body.layer4.0.conv1.weight',\n",
       "  'backbone.body.layer4.0.conv2.weight',\n",
       "  'backbone.body.layer4.0.conv3.weight',\n",
       "  'backbone.body.layer4.0.downsample.0.weight',\n",
       "  'backbone.body.layer4.1.conv1.weight',\n",
       "  'backbone.body.layer4.1.conv2.weight',\n",
       "  'backbone.body.layer4.1.conv3.weight',\n",
       "  'backbone.body.layer4.2.conv1.weight',\n",
       "  'backbone.body.layer4.2.conv2.weight',\n",
       "  'backbone.body.layer4.2.conv3.weight',\n",
       "  'backbone.fpn.inner_blocks.0.weight',\n",
       "  'backbone.fpn.inner_blocks.0.bias',\n",
       "  'backbone.fpn.inner_blocks.1.weight',\n",
       "  'backbone.fpn.inner_blocks.1.bias',\n",
       "  'backbone.fpn.inner_blocks.2.weight',\n",
       "  'backbone.fpn.inner_blocks.2.bias',\n",
       "  'backbone.fpn.inner_blocks.3.weight',\n",
       "  'backbone.fpn.inner_blocks.3.bias',\n",
       "  'backbone.fpn.layer_blocks.0.weight',\n",
       "  'backbone.fpn.layer_blocks.0.bias',\n",
       "  'backbone.fpn.layer_blocks.1.weight',\n",
       "  'backbone.fpn.layer_blocks.1.bias',\n",
       "  'backbone.fpn.layer_blocks.2.weight',\n",
       "  'backbone.fpn.layer_blocks.2.bias',\n",
       "  'backbone.fpn.layer_blocks.3.weight',\n",
       "  'backbone.fpn.layer_blocks.3.bias',\n",
       "  'rpn.head.conv.weight',\n",
       "  'rpn.head.conv.bias',\n",
       "  'rpn.head.cls_logits.weight',\n",
       "  'rpn.head.cls_logits.bias',\n",
       "  'rpn.head.bbox_pred.weight',\n",
       "  'rpn.head.bbox_pred.bias',\n",
       "  'roi_heads.box_head.fc6.weight',\n",
       "  'roi_heads.box_head.fc6.bias',\n",
       "  'roi_heads.box_head.fc7.weight',\n",
       "  'roi_heads.box_head.fc7.bias',\n",
       "  'roi_heads.box_predictor.cls_score.weight',\n",
       "  'roi_heads.box_predictor.cls_score.bias',\n",
       "  'roi_heads.box_predictor.bbox_pred.weight',\n",
       "  'roi_heads.box_predictor.bbox_pred.bias'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gather the name of each model parameter, then show how many there are and the full list.\n",
    "names = [param_name for param_name, _param in model.named_parameters()]\n",
    "len(names), names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 0.1098,  1.0315, -0.1079, -0.7713],\n",
       "          [ 0.5492,  0.8812,  0.3834,  0.5156],\n",
       "          [-1.7445,  1.0657, -0.6064, -2.1251]],\n",
       " \n",
       "         [[ 0.3617,  0.1533, -1.5400,  0.0947],\n",
       "          [ 0.4430,  0.1719,  2.0696, -1.4195],\n",
       "          [-0.1677, -1.1848,  0.6109,  0.1091]]]),\n",
       " torch.Size([2, 3, 4]))"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw the demo tensor from a locally seeded generator so that a re-run reproduces the\n",
    "# same values without touching the global RNG state.\n",
    "a = torch.randn(2, 3, 4, generator=torch.Generator().manual_seed(0))\n",
    "a, a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 2, 3])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.permute(2, 0, 1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 3])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.transpose(-1, -2).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.3313, 1.0156,    nan,    nan],\n",
       "         [0.7411, 0.9387, 0.6192, 0.7180],\n",
       "         [   nan, 1.0323,    nan,    nan]],\n",
       "\n",
       "        [[0.6014, 0.3916,    nan, 0.3078],\n",
       "         [0.6656, 0.4147, 1.4386,    nan],\n",
       "         [   nan,    nan, 0.7816, 0.3303]]])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.sqrt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  1., -1., -1.],\n",
       "         [ 1.,  1.,  1.,  1.],\n",
       "         [-1.,  1., -1., -1.]],\n",
       "\n",
       "        [[ 1.,  1., -1.,  1.],\n",
       "         [ 1.,  1.,  1., -1.],\n",
       "         [-1., -1.,  1.,  1.]]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.sign()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.3313,  1.0156, -0.3285, -0.8782],\n",
       "         [ 0.7411,  0.9387,  0.6192,  0.7180],\n",
       "         [-1.3208,  1.0323, -0.7787, -1.4578]],\n",
       "\n",
       "        [[ 0.6014,  0.3916, -1.2410,  0.3078],\n",
       "         [ 0.6656,  0.4147,  1.4386, -1.1914],\n",
       "         [-0.4095, -1.0885,  0.7816,  0.3303]]])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Signed square root: sqrt of the magnitudes with each element's sign restored\n",
    "# (a plain a.sqrt() yields nan for negative entries).\n",
    "torch.sign(a) * torch.sqrt(torch.abs(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "09f402a5df0ab33181308beb105cd67fedce8b63239961f335319bf6515a38ef"
  },
  "kernelspec": {
   "display_name": "Python 3.6.5 ('ps_torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
